"""Model configuration settings."""

import jax
import jax.numpy as jnp

from ml_collections.config_dict import ConfigDict

def get_config():
    """Return a freshly built model configuration.

    The configuration has three top-level sections, in this order:

    * ``embedding`` -- dtypes and widths for the embedding stage, the number
      of atom types, distance/bond embedding switches, ``dis_self``, the
      cutoff value and the name of the cutoff function.
    * ``model`` -- feature and embedding widths, interaction switches, the
      number of interaction blocks and the number of attention heads.
    * ``readout`` -- node-representation width and output dimension.

    A new ConfigDict is created on every call, so callers may mutate the
    result without affecting later calls.

    Returns:
        ConfigDict: the nested configuration described above.
    """
    # Embedding-stage settings. Dtypes are stored as jnp dtype objects.
    embedding = dict(
        fp_type=jnp.float32,
        int_type=jnp.int32,
        dim_node=64,
        dim_edge=64,
        num_atom_types=64,
        is_emb_dis=True,
        is_emb_bond=False,
        dis_self=0.05,  # NOTE(review): units/meaning not visible here -- confirm
        cutoff=1.0,     # NOTE(review): units depend on the consumer -- confirm
        cutoff_func='smooth',
        name='embedding',
    )

    # Core interaction-network settings.
    model = dict(
        dim_feature=128,
        dim_node_emb=128,
        dim_edge_emb=128,
        is_edge_update=False,
        is_coupled_interaction=True,
        n_interaction=3,
        n_heads=8,
    )

    # Output head settings.
    readout = dict(
        dim_node_rep=128,
        dim_output=1,
    )

    sections = dict(embedding=embedding, model=model, readout=readout)
    return ConfigDict(sections)